import argparse
from argparse import Action


def default_argument_parser():
    """Create the command-line parser for an experiment run.

    The parser is named ``SSL`` and registers experiment/model options,
    checkpoint-saving options, optimization options and distributed-training
    options. ``--exp-options`` is parsed by ``DictAction`` so that
    ``key=value`` pairs arrive as a dictionary.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(prog="SSL")
    add = parser.add_argument

    # experiment / model
    add("-expn", "--experiment-name", default="ImageNet1K", type=str)
    add("--ws", default=False, action="store_true")
    add("--single-aug", default=False, action="store_true")
    exp_options_help = (
        "override some settings in the exp, the key-value pair in xxx=yyy format will be merged into exp. "
        'If the value to be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space is allowed."
    )
    add("--exp-options", action=DictAction, nargs="+", help=exp_options_help)
    add("--num-classes", default=100, type=int)
    add("--mode", default="fixnorm", type=str)
    add("--last_relu", action="store_true", default=False)
    add("--loss", type=str, default="sm")
    add("--alpha", type=float, default=0.0)
    add("--adjusted", action="store_true", default=False)
    add("--eps", type=float, default=0.8)
    add("--reg", type=float, default=0.0)

    # save
    add("--output-dir", default="outputs", type=str)
    add("--resume", action="store_true", default=False)
    add("--no-oss-saved", action="store_true", default=False)

    # optimization
    add("-bs", "--batch-size", default=256, type=int, help="batch size")
    add("--start-epoch", default=0, type=int)
    add("--total-epochs", default=100, type=int)
    add("--warmup-epochs", default=10, type=int)

    # distributed
    add("--deepspeed", action="store_true", default=False, help="use DeepSpeed")
    add("--dist-backend", type=str, default="nccl", help="distributed backend")
    add("--dist-url", type=str, default=None, help="url used to set up distributed training")
    add("--rank", type=int, default=None, help="node rank for distributed training")
    add("--debug", action="store_true", default=False)
    add("--gpu", type=int, default=None, help="GPU id to use.")
    add("--seed", type=int, default=None, help="seed for initializing training. ")
    add("--num-machines", type=int, default=None)

    return parser


class DictAction(Action):
    """argparse action that collects ``KEY=VALUE`` arguments into a dict.

    Each argument is split on the first ``=``. List values can be passed as
    comma separated values, i.e. 'KEY=V1,V2,V3', or with explicit brackets,
    i.e. 'KEY=[V1,V2,V3]'. Nested brackets build list/tuple values, e.g.
    'KEY=[(V1,V2),(V3,V4)]'.

    When the option is given several times on one command line, the pairs are
    merged into a single dictionary (later keys override earlier ones).
    Malformed items (no ``=`` or imbalanced brackets) are reported through
    ``argparse.ArgumentError`` so the parser prints a normal usage error.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert a scalar string to ``int``, ``float`` or ``bool`` if possible.

        Args:
            val (str): Scalar value string.

        Returns:
            int | float | bool | str: The converted value, or ``val`` itself
            when it is neither numeric nor ``true``/``false`` (any case).
        """
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        lowered = val.lower()
        if lowered in ("true", "false"):
            return lowered == "true"
        return val

    @staticmethod
    def _parse_iterable(val):
        """Parse iterable values in the string.

        All elements inside '()' or '[]' are treated as iterable values.
        Quote characters at both ends and all spaces are removed first.

        Args:
            val (str): Value string.

        Returns:
            list | tuple | int | float | bool | str: The expanded list or
            tuple, or a scalar when the string holds a single value.

        Raises:
            ValueError: If the numbers of opening and closing brackets differ
                in a string that has to be split.

        Examples:
            >>> DictAction._parse_iterable('1,2,3')
            [1, 2, 3]
            >>> DictAction._parse_iterable('[a, b, c]')
            ['a', 'b', 'c']
            >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
            [(1, 2, 3), ['a', 'b'], 'c']
            >>> DictAction._parse_iterable('(1,2),(3,4)')
            [(1, 2), (3, 4)]
        """

        def find_next_comma(string):
            """Return the index of the first ',' outside any brackets.

            If no such ',' exists, return the string length. Running
            counters replace per-prefix recounting, so the scan is linear.
            """
            if string.count("(") != string.count(")") or string.count("[") != string.count("]"):
                raise ValueError(f"Imbalanced brackets exist in {string}")
            round_open = 0
            square_open = 0
            for idx, char in enumerate(string):
                if char == "," and round_open == 0 and square_open == 0:
                    return idx
                if char == "(":
                    round_open += 1
                elif char == ")":
                    round_open -= 1
                elif char == "[":
                    square_open += 1
                elif char == "]":
                    square_open -= 1
            return len(string)

        def is_wrapped(string, opener, closer):
            """Tell whether the bracket opened at index 0 closes at the last char.

            '(1,2),(3,4)' starts with '(' and ends with ')' but those two do
            not belong together, so it must not be unwrapped.
            """
            if not (string.startswith(opener) and string.endswith(closer)):
                return False
            depth = 0
            for idx, char in enumerate(string):
                if char in "([":
                    depth += 1
                elif char in ")]":
                    depth -= 1
                if depth == 0:
                    return idx == len(string) - 1
            return False

        # Strip ' and " characters and remove spaces.
        val = val.strip("'\"").replace(" ", "")
        is_tuple = False
        unwrapped = False
        if is_wrapped(val, "(", ")"):
            is_tuple = True
            unwrapped = True
            val = val[1:-1]
        elif is_wrapped(val, "[", "]"):
            unwrapped = True
            val = val[1:-1]
        elif "," not in val:
            # val is a single value
            return DictAction._parse_int_float_bool(val)

        values = []
        rest = val
        while rest:
            comma_idx = find_next_comma(rest)
            if not unwrapped and not values and comma_idx == len(rest):
                # Every comma is nested (e.g. 'foo(1,2)') and nothing was
                # unwrapped: recursing would see the same string forever, so
                # treat the whole string as one scalar.
                return DictAction._parse_int_float_bool(val)
            values.append(DictAction._parse_iterable(rest[:comma_idx]))
            rest = rest[comma_idx + 1:]
        if is_tuple:
            return tuple(values)
        return values

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``KEY=VALUE`` items and store them as a dict on ``namespace``.

        Args:
            parser (argparse.ArgumentParser): The parser invoking this action.
            namespace (argparse.Namespace): Object receiving the result.
            values (list[str] | str | None): Raw command-line items.
            option_string (str | None): The option string that was used.

        Raises:
            argparse.ArgumentError: If an item has no ``=`` or its value has
                imbalanced brackets.
        """
        # Start from a copy of what is already stored so that repeating the
        # option merges pairs instead of discarding the earlier ones.
        existing = getattr(namespace, self.dest, None)
        options = dict(existing) if isinstance(existing, dict) else {}

        if values is None:
            values = []
        elif isinstance(values, str):
            # A single item arrives as a bare string when nargs is None/'?'.
            values = [values]

        for kv in values:
            key, sep, val = kv.partition("=")
            if not sep:
                raise argparse.ArgumentError(self, f"expected KEY=VALUE, got {kv!r}")
            try:
                options[key] = self._parse_iterable(val)
            except ValueError as err:
                raise argparse.ArgumentError(self, str(err)) from err
        setattr(namespace, self.dest, options)
